import torch
from torch.nn import Parameter
from torch.nn.init import xavier_normal_
# from layers.message_passing import MessagePassing
from torch_geometric.nn import MessagePassing
from torch_scatter import scatter_add

def get_param(shape):
	"""Create a trainable parameter of the given shape, Xavier-normal initialised.

	Args:
		shape: iterable of dimension sizes; unpacked into the tensor constructor.
			``xavier_normal_`` needs at least two dimensions.

	Returns:
		A ``torch.nn.Parameter`` holding the freshly initialised weights.
	"""
	weights = torch.Tensor(*shape)	# uninitialised storage, filled in place below
	xavier_normal_(weights)
	return Parameter(weights)

class CompGCNConv(MessagePassing):
	"""Message-passing layer that composes node and relation embeddings by subtraction.

	For every edge the source-node embedding has the (projected) relation
	embedding of that edge subtracted from it and the result is multiplied by
	``w_edge``. Every node additionally receives a self-loop message built the
	same way from the learned ``loop_rel`` embedding and ``w_loop``. The layer
	output is the mean of the edge result and the self-loop result; no
	activation, dropout or normalisation is applied in ``forward``.

	Aggregation of the per-edge messages is whatever the ``MessagePassing``
	base class does by default (this class passes no ``aggr`` argument).

	Args:
		in_channels: size of the input node embeddings.
		out_channels: size of the output node embeddings.
		rel_embed_size: size of the last dimension of ``edge_attr``; defaults
			to ``in_channels`` when ``None``.
	"""

	def __init__(self, in_channels, out_channels, rel_embed_size=None):
		# Zero-argument super(): ``super(self.__class__, self)`` recurses
		# forever as soon as this class is subclassed.
		super().__init__()

		self.in_channels	= in_channels
		self.out_channels	= out_channels
		if rel_embed_size is None:
			rel_embed_size = in_channels
		self.act 		= torch.tanh	# kept as an attribute; not applied in forward()
		self.device		= None		# refreshed from the inputs on every forward()

		self.w_loop		= get_param((in_channels, out_channels))	# weight for self-loop messages
		self.w_edge		= get_param((in_channels, out_channels))	# weight for edge messages
		self.w_rel 		= get_param((rel_embed_size, in_channels))	# projects edge_attr into node space
		self.loop_rel 		= get_param((1, in_channels))			# relation embedding of the self-loop

	def forward(self, x, edge_index, edge_attr):
		"""Run one round of message passing.

		Args:
			x: node embeddings; ``x.size(0)`` is taken as the number of nodes
				and the last dimension must be ``in_channels``.
			edge_index: ``(2, num_edges)`` integer tensor of node indices.
			edge_attr: per-edge relation features whose last dimension matches
				the first dimension of ``w_rel``.

		Returns:
			``0.5 * edge_result + 0.5 * self_loop_result``.

		Raises:
			ValueError: if ``edge_index`` refers to a node index that is not
				smaller than the number of rows in ``x``.
		"""
		num_edges = edge_index.size(1)
		num_ent   = x.size(0)
		# Follow the inputs instead of caching the device of the first call,
		# so the layer keeps working after being moved to another device.
		self.device = edge_index.device
		self.edge_index = edge_index

		# ``max()`` is undefined on an empty tensor, so only validate when
		# there is at least one edge.
		if num_edges > 0 and num_ent <= edge_index.max().item():
			raise ValueError(
				'edge_index refers to node {} but x only has {} rows'.format(
					edge_index.max().item(), num_ent))

		# One (i, i) self-loop per node, created directly on the target device.
		node_ids = torch.arange(num_ent, device=self.device)
		self.loop_index = torch.stack([node_ids, node_ids])
		loop_embed = self.loop_rel.expand(num_ent, -1)

		# Degree normalisation (see compute_norm) is currently disabled.
		self.edge_norm = None

		rel_embed = torch.matmul(edge_attr, self.w_rel)
		edge_res	= self.propagate(self.edge_index, x=x, rel_embed=rel_embed, edge_norm=self.edge_norm, mode='edge')
		loop_res	= self.propagate(self.loop_index, x=x, rel_embed=loop_embed, edge_norm=None, mode='loop')
		out		= edge_res*(1/2) + loop_res*(1/2)

		return out

	def rel_transform(self, ent_embed, rel_embed):
		"""Compose entity and relation embeddings; only subtraction is implemented."""
		trans_embed  = ent_embed - rel_embed

		return trans_embed

	def message(self, x_j, rel_embed, edge_norm, mode):
		"""Build the per-edge message.

		``mode`` selects the weight matrix by attribute name (``'edge'`` ->
		``w_edge``, ``'loop'`` -> ``w_loop``). When ``edge_norm`` is given,
		each message row is scaled by the corresponding entry.
		"""
		weight 	= getattr(self, 'w_{}'.format(mode))
		xj_rel  = self.rel_transform(x_j, rel_embed)
		out	= torch.mm(xj_rel, weight)

		return out if edge_norm is None else out * edge_norm.view(-1, 1)

	def update(self, aggr_out):
		"""Return the aggregated messages unchanged."""
		return aggr_out

	def compute_norm(self, edge_index, num_ent):
		"""Return the per-edge weight ``deg^-0.5[row] * deg^-0.5[col]``.

		``deg`` counts, for each of the ``num_ent`` nodes, how often it
		appears in the first row of ``edge_index``. Nodes with degree zero
		get an inverse-root degree of 0 instead of infinity. Not used by
		``forward`` at the moment.
		"""
		row, col	= edge_index
		edge_weight 	= torch.ones_like(row).float()
		deg		= scatter_add( edge_weight, row, dim=0, dim_size=num_ent)	# Summing number of weights of the edges
		deg_inv		= deg.pow(-0.5)							# D^{-0.5}
		deg_inv[deg_inv	== float('inf')] = 0					# isolated nodes: 0 rather than inf
		norm		= deg_inv[row] * edge_weight * deg_inv[col]			# D^{-0.5} A D^{-0.5}, per edge

		return norm

	def __repr__(self):
		# ``num_rels`` was never set on this class, so the old repr raised
		# AttributeError; report the relation embedding size instead.
		return '{}({}, {}, rel_embed_size={})'.format(
			self.__class__.__name__, self.in_channels, self.out_channels, self.w_rel.size(0))